{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c459f181",
   "metadata": {},
   "source": [
    "# AUC multiclass computation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72423baf",
   "metadata": {},
   "source": [
    "## Initial imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "48a9d90d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.optim import SGD, lr_scheduler\n",
    "\n",
    "from pytorch_widedeep import Trainer\n",
    "from pytorch_widedeep.preprocessing import TabPreprocessor\n",
    "from pytorch_widedeep.models import TabMlp, WideDeep\n",
    "from torchmetrics import AUROC\n",
    "from pytorch_widedeep.initializers import XavierNormal\n",
    "from pytorch_widedeep.datasets import load_ecoli\n",
    "from pytorch_widedeep.utils import LabelEncoder\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Raise pandas' display limits so wide/long frames are shown in full\n",
    "pd.options.display.max_columns = 200\n",
    "pd.options.display.max_rows = 300"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "32df29aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SequenceName</th>\n",
       "      <th>mcg</th>\n",
       "      <th>gvh</th>\n",
       "      <th>lip</th>\n",
       "      <th>chg</th>\n",
       "      <th>aac</th>\n",
       "      <th>alm1</th>\n",
       "      <th>alm2</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>AAT_ECOLI</td>\n",
       "      <td>0.49</td>\n",
       "      <td>0.29</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.56</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.35</td>\n",
       "      <td>cp</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ACEA_ECOLI</td>\n",
       "      <td>0.07</td>\n",
       "      <td>0.40</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.54</td>\n",
       "      <td>0.35</td>\n",
       "      <td>0.44</td>\n",
       "      <td>cp</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>ACEK_ECOLI</td>\n",
       "      <td>0.56</td>\n",
       "      <td>0.40</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.49</td>\n",
       "      <td>0.37</td>\n",
       "      <td>0.46</td>\n",
       "      <td>cp</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>ACKA_ECOLI</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.49</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.52</td>\n",
       "      <td>0.45</td>\n",
       "      <td>0.36</td>\n",
       "      <td>cp</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>ADI_ECOLI</td>\n",
       "      <td>0.23</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.48</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.55</td>\n",
       "      <td>0.25</td>\n",
       "      <td>0.35</td>\n",
       "      <td>cp</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  SequenceName   mcg   gvh   lip  chg   aac  alm1  alm2 class\n",
       "0    AAT_ECOLI  0.49  0.29  0.48  0.5  0.56  0.24  0.35    cp\n",
       "1   ACEA_ECOLI  0.07  0.40  0.48  0.5  0.54  0.35  0.44    cp\n",
       "2   ACEK_ECOLI  0.56  0.40  0.48  0.5  0.49  0.37  0.46    cp\n",
       "3   ACKA_ECOLI  0.59  0.49  0.48  0.5  0.52  0.45  0.36    cp\n",
       "4    ADI_ECOLI  0.23  0.32  0.48  0.5  0.55  0.25  0.35    cp"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the ecoli dataset as a DataFrame and preview its first rows\n",
    "df = load_ecoli(as_frame=True)\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3e344836",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "class\n",
       "cp     143\n",
       "im      77\n",
       "pp      52\n",
       "imU     35\n",
       "om      20\n",
       "omL      5\n",
       "imS      2\n",
       "imL      2\n",
       "Name: count, dtype: int64"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows per class: the counts show how imbalanced the target is\n",
    "class_counts = df[\"class\"].value_counts()\n",
    "class_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6ba3fc9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the three rarest classes (omL, imS and imL have 5 rows or fewer)\n",
    "# and renumber the index from 0 in the same chained expression\n",
    "rare_classes = [\"omL\", \"imS\", \"imL\"]\n",
    "df = df.loc[~df[\"class\"].isin(rare_classes)].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8753d67d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer-encode the target column, then shift every code down by one\n",
    "# NOTE(review): presumably the encoder's codes start at 1, so the shift\n",
    "# makes the labels 0-based -- confirm against the LabelEncoder docs\n",
    "encoder = LabelEncoder([\"class\"])\n",
    "df_enc = encoder.fit_transform(df)\n",
    "df_enc[\"class\"] -= 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "af1fbef5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SequenceName is not needed in this example, so remove it\n",
    "df_enc = df_enc.drop(\"SequenceName\", axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4de303c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stratified 80/10/10 split: hold out 20% of the rows, then halve the\n",
    "# hold-out into validation and test sets\n",
    "df_train, df_holdout = train_test_split(\n",
    "    df_enc, test_size=0.2, stratify=df_enc[\"class\"], random_state=1\n",
    ")\n",
    "df_valid, df_test = train_test_split(\n",
    "    df_holdout, test_size=0.5, stratify=df_holdout[\"class\"], random_state=1\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fbcca37",
   "metadata": {},
   "source": [
    "## Preparing the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "31be8a96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# every column except the target is treated as a continuous feature\n",
    "continuous_cols = [col for col in df_enc.columns if col != \"class\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6f352329",
   "metadata": {},
   "outputs": [],
   "source": [
    "# deeptabular\n",
    "# `cols_to_scale` replaces the deprecated `scale=True`; listing every\n",
    "# continuous column keeps all of them standardised\n",
    "tab_preprocessor = TabPreprocessor(\n",
    "    continuous_cols=continuous_cols, cols_to_scale=continuous_cols\n",
    ")\n",
    "# fit on the training split only, then reuse the fitted preprocessor\n",
    "X_tab_train = tab_preprocessor.fit_transform(df_train)\n",
    "X_tab_valid = tab_preprocessor.transform(df_valid)\n",
    "X_tab_test = tab_preprocessor.transform(df_test)\n",
    "\n",
    "# target\n",
    "y_train = df_train[\"class\"].values\n",
    "y_valid = df_valid[\"class\"].values\n",
    "y_test = df_test[\"class\"].values\n",
    "\n",
    "# dictionaries in the form passed to Trainer.fit as X_train / X_val\n",
    "X_train = {\"X_tab\": X_tab_train, \"target\": y_train}\n",
    "X_val = {\"X_tab\": X_tab_valid, \"target\": y_valid}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39d37f8f",
   "metadata": {},
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1928f7f2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (deeptabular): Sequential(\n",
       "    (0): TabMlp(\n",
       "      (cont_norm): Identity()\n",
       "      (encoder): MLP(\n",
       "        (mlp): Sequential(\n",
       "          (dense_layer_0): Sequential(\n",
       "            (0): Linear(in_features=7, out_features=200, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (dense_layer_1): Sequential(\n",
       "            (0): Linear(in_features=200, out_features=100, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=100, out_features=5, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one output unit per distinct class in the target\n",
    "n_classes = df_enc[\"class\"].nunique()\n",
    "\n",
    "deeptabular = TabMlp(\n",
    "    continuous_cols=tab_preprocessor.continuous_cols,\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    ")\n",
    "model = WideDeep(pred_dim=n_classes, deeptabular=deeptabular)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "20b0881e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# multiclass AUROC metric, sized to the number of distinct target classes\n",
    "auroc = AUROC(task=\"multiclass\", num_classes=df_enc[\"class\"].nunique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3820cfa9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|███████████████████████████████████████████████████| 6/6 [00:00<00:00, 54.59it/s, loss=0.109, metrics={'MulticlassAUROC': 0.314}]\n",
      "valid: 100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00, 98.35it/s, loss=0.105, metrics={'MulticlassAUROC': 0.2558}]\n",
      "epoch 2: 100%|██████████████████████████████████████████████████| 6/6 [00:00<00:00, 91.55it/s, loss=0.105, metrics={'MulticlassAUROC': 0.3546}]\n",
      "valid: 100%|███████████████████████████████████████████████████| 1/1 [00:00<00:00, 111.68it/s, loss=0.101, metrics={'MulticlassAUROC': 0.2737}]\n",
      "epoch 3: 100%|████████████████████████████████████████████████████| 6/6 [00:00<00:00, 62.55it/s, loss=0.1, metrics={'MulticlassAUROC': 0.3795}]\n",
      "valid: 100%|██████████████████████████████████████████████████| 1/1 [00:00<00:00, 108.51it/s, loss=0.0966, metrics={'MulticlassAUROC': 0.3053}]\n",
      "epoch 4: 100%|█████████████████████████████████████████████████| 6/6 [00:00<00:00, 99.35it/s, loss=0.0965, metrics={'MulticlassAUROC': 0.3809}]\n",
      "valid: 100%|██████████████████████████████████████████████████| 1/1 [00:00<00:00, 117.73it/s, loss=0.0962, metrics={'MulticlassAUROC': 0.3089}]\n",
      "epoch 5: 100%|████████████████████████████████████████████████| 6/6 [00:00<00:00, 110.56it/s, loss=0.0967, metrics={'MulticlassAUROC': 0.3509}]\n",
      "valid: 100%|██████████████████████████████████████████████████| 1/1 [00:00<00:00, 127.35it/s, loss=0.0958, metrics={'MulticlassAUROC': 0.3089}]\n"
     ]
    }
   ],
   "source": [
    "# Optimizer and LR scheduler (StepLR, step_size=3) for the deeptabular component\n",
    "deep_opt = SGD(model.deeptabular.parameters(), lr=0.1)\n",
    "deep_sch = lr_scheduler.StepLR(deep_opt, step_size=3)\n",
    "\n",
    "# Trainer: multiclass focal loss objective, tracking the AUROC metric\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    objective=\"multiclass_focal_loss\",\n",
    "    optimizers={\"deeptabular\": deep_opt},\n",
    "    lr_schedulers={\"deeptabular\": deep_sch},\n",
    "    initializers={\"deeptabular\": XavierNormal},\n",
    "    metrics=[auroc],\n",
    ")\n",
    "\n",
    "trainer.fit(X_train=X_train, X_val=X_val, batch_size=50, n_epochs=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3e62133-0707-4c2e-b411-7479ba967e04",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
